{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "LANL Summer School Demo.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/google/jax-md/blob/nvt_refactor/notebooks/lanl_summer_school_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "odlcm3Jmuglm",
        "cellView": "form"
      },
      "source": [
        "#@title Import & Util\n",
        "\n",
        "!pip install -q git+https://www.github.com/google/jax-md\n",
        "!pip install dm-haiku\n",
        "\n",
        "import jax.numpy as np\n",
        "from IPython.display import set_matplotlib_formats\n",
        "set_matplotlib_formats('pdf', 'svg')\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import pickle\n",
        "\n",
        "import warnings\n",
        "warnings.simplefilter(\"ignore\")\n",
        "\n",
        "sns.set_style(style='white')\n",
        "background_color = [56 / 256] * 3\n",
        "def plot(x, y, *args):\n",
        "  \"\"\"Draw y against x as a thick line on a white axes background.\n",
        "\n",
        "  Extra positional arguments (e.g. a format string) are forwarded to\n",
        "  `plt.plot`.\n",
        "  \"\"\"\n",
        "  white = [1, 1, 1]\n",
        "  plt.plot(x, y, *args, linewidth=3)\n",
        "  current_axes = plt.gca()\n",
        "  current_axes.set_facecolor(white)\n",
        "def draw(R, **kwargs):\n",
        "  \"\"\"Scatter-plot the 2D points in R as large discs on a dark background.\n",
        "\n",
        "  Args:\n",
        "    R: array indexed as R[:, 0] (x) and R[:, 1] (y). The axes run from 0 to\n",
        "      the largest coordinate found in each column.\n",
        "    **kwargs: forwarded to `plt.scatter`. A pale default color is used only\n",
        "      when the caller supplies neither `c` nor `color`.\n",
        "  \"\"\"\n",
        "  # Checking `c` alone would overwrite an explicit `color=` from the caller.\n",
        "  if 'c' not in kwargs and 'color' not in kwargs:\n",
        "    kwargs['color'] = [1, 1, 0.9]\n",
        "  ax = plt.axes(xlim=(0, float(np.max(R[:, 0]))),\n",
        "                ylim=(0, float(np.max(R[:, 1]))))\n",
        "  ax.get_xaxis().set_visible(False)\n",
        "  ax.get_yaxis().set_visible(False)\n",
        "  ax.set_facecolor(background_color)\n",
        "  plt.scatter(R[:, 0], R[:, 1],  marker='o', s=1024, **kwargs)\n",
        "  plt.gcf().patch.set_facecolor(background_color)\n",
        "  plt.gcf().set_size_inches(6, 6)\n",
        "  plt.tight_layout()\n",
        "def draw_big(R, **kwargs):\n",
        "  \"\"\"Scatter-plot many 2D points as small rasterized dots on a dark figure.\n",
        "\n",
        "  Args:\n",
        "    R: array indexed as R[:, 0] (x) and R[:, 1] (y). The axes run from 0 to\n",
        "      the largest coordinate found in each column.\n",
        "    **kwargs: forwarded to `plt.scatter`. A pale default color is used only\n",
        "      when the caller supplies neither `c` nor `color`.\n",
        "  \"\"\"\n",
        "  # Checking `c` alone would overwrite an explicit `color=` from the caller.\n",
        "  if 'c' not in kwargs and 'color' not in kwargs:\n",
        "    kwargs['color'] = [1, 1, 0.9]\n",
        "  # Open a fresh, higher-resolution figure for the large system.\n",
        "  plt.figure(dpi=128)\n",
        "  ax = plt.axes(xlim=(0, float(np.max(R[:, 0]))),\n",
        "                ylim=(0, float(np.max(R[:, 1]))))\n",
        "  ax.get_xaxis().set_visible(False)\n",
        "  ax.get_yaxis().set_visible(False)\n",
        "  ax.set_facecolor(background_color)\n",
        "  s = plt.scatter(R[:, 0], R[:, 1], marker='o', s=0.5, **kwargs)\n",
        "  # Rasterize the point cloud so vector output formats stay small.\n",
        "  s.set_rasterized(True)\n",
        "  plt.gcf().patch.set_facecolor(background_color)\n",
        "  plt.gcf().set_size_inches(10, 10)\n",
        "  plt.tight_layout()\n",
        "def draw_displacement(R, dR):\n",
        "  \"\"\"Overlay one arrow per point: dR drawn at position R on the current plot.\"\"\"\n",
        "  x, y = R[:, 0], R[:, 1]\n",
        "  u, v = dR[:, 0], dR[:, 1]\n",
        "  arrow_color = [1, 0.5, 0.5]\n",
        "  plt.quiver(x, y, u, v, color=arrow_color)\n",
        "\n",
        "# Progress Bars\n",
        "\n",
        "from IPython.display import HTML, display\n",
        "import time\n",
        "\n",
        "def ProgressIter(iter_fun, iter_len=0):\n",
        "  \"\"\"Yield the items of `iter_fun` while updating a notebook progress bar.\n",
        "\n",
        "  Args:\n",
        "    iter_fun: iterable to walk; must support `len()` unless `iter_len` is set.\n",
        "    iter_len: value used as the bar's maximum. A falsy value means\n",
        "      `len(iter_fun)` is used instead.\n",
        "  \"\"\"\n",
        "  total = iter_len if iter_len else len(iter_fun)\n",
        "  handle = display(progress(0, total), display_id=True)\n",
        "  for completed, item in enumerate(iter_fun, start=1):\n",
        "    yield item\n",
        "    # Runs once the consumer asks for the next item, i.e. after this one is done.\n",
        "    handle.update(progress(completed, total))\n",
        "\n",
        "def progress(value, max):\n",
        "    \"\"\"Return an HTML <progress> bar showing `value` out of `max`.\"\"\"\n",
        "    # Attributes are separated by whitespace only; a comma after `max` would\n",
        "    # be parsed as a stray attribute.\n",
        "    return HTML(\"\"\"\n",
        "        <progress\n",
        "            value='{value}'\n",
        "            max='{max}'\n",
        "            style='width: 45%'\n",
        "        >\n",
        "            {value}\n",
        "        </progress>\n",
        "    \"\"\".format(value=value, max=max))\n",
        "\n",
        "# Data Loading\n",
        "\n",
        "!wget -O silica_train.npz https://www.dropbox.com/s/3dojk4u4di774ve/silica_train.npz?dl=0\n",
        "!wget https://github.com/google/jax-md/blob/master/examples/models/si_gnn.pickle?raw=true\n",
        "\n",
        "import numpy as onp\n",
        "\n",
        "# arr_0..arr_2 populate Rs/Es/Fs (first 200 entries kept); arr_3..arr_5\n",
        "# populate the test_* arrays in full.\n",
        "with open('silica_train.npz', 'rb') as f:\n",
        "  files = onp.load(f)\n",
        "  Rs, Es, Fs = (np.array(files[name])[:200]\n",
        "                for name in ('arr_0', 'arr_1', 'arr_2'))\n",
        "  test_Rs, test_Es, test_Fs = (np.array(files[name])\n",
        "                               for name in ('arr_3', 'arr_4', 'arr_5'))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NutOCjdwpneq"
      },
      "source": [
        "## LANL Summer School Demo\n",
        "\n",
        "www.github.com/google/jax-md -> notebooks -> lanl_summer_school_demo.ipynb"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9_FfC_k1DMfB"
      },
      "source": [
        "### Energy and Automatic Differentiation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ISh1a8bK33QZ"
      },
      "source": [
        "$u(r) = \\begin{cases}\\frac13(1 - r)^3 & \\text{if $r < 1$} \\\\ 0 & \\text{otherwise} \\end{cases}$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uQs4C7ggBxBM"
      },
      "source": [
        "import jax.numpy as np\n",
        "\n",
        "def soft_sphere(r):\n",
        "  \"\"\"Pair energy 1/3 * (1 - r)^3 for r < 1, and zero otherwise.\"\"\"\n",
        "  overlap_energy = 1/3 * (1 - r) ** 3\n",
        "  is_overlapping = r < 1\n",
        "  return np.where(is_overlapping, overlap_energy, 0.)\n",
        "\n",
        "print(soft_sphere(0.5))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gfEim-vuuXKd"
      },
      "source": [
        "r = np.linspace(0, 2., 200)\n",
        "plot(r, soft_sphere(r))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Eqt0U4tArwlE"
      },
      "source": [
        "We can compute its derivative automatically with `grad`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QrOivTPpCJV7"
      },
      "source": [
        "# `grad` is part of the public top-level `jax` namespace; the old `jax.api`\n",
        "# module has been removed from JAX, so importing from it fails.\n",
        "from jax import grad\n",
        "\n",
        "# Derivative of the pair energy with respect to the distance r.\n",
        "du_dr = grad(soft_sphere)\n",
        "\n",
        "print(du_dr(0.5))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ws8igUYu20tK"
      },
      "source": [
        "We can vectorize the derivative computation over many radii with `vmap`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C3sx2gTb23s7"
      },
      "source": [
        "# `vmap` is part of the public top-level `jax` namespace; the old `jax.api`\n",
        "# module has been removed from JAX, so importing from it fails.\n",
        "from jax import vmap\n",
        "\n",
        "# Map the scalar derivative over an array of radii.\n",
        "du_dr_v = vmap(du_dr)\n",
        "\n",
        "plot(r, soft_sphere(r))\n",
        "# The second curve is the negated derivative, -du/dr.\n",
        "plot(r, -du_dr_v(r))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b6uiqzqS04EN"
      },
      "source": [
        "### Randomly Initialize a System"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EmDIRcGvBaN6"
      },
      "source": [
        "from jax import random\n",
        "\n",
        "key = random.PRNGKey(0)\n",
        "\n",
        "particle_count = 128\n",
        "dim = 2"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BXOafNfpppzn"
      },
      "source": [
        "from jax_md.quantity import box_size_at_number_density\n",
        "\n",
        "# number_density = N / V\n",
        "side_length = box_size_at_number_density(particle_count = particle_count, \n",
        "                                         number_density = 1.0, \n",
        "                                         spatial_dimension = dim)\n",
        "\n",
        "R = random.uniform(key, (particle_count, dim), maxval=side_length)\n",
        "draw(R)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "APxYUxxo7Dzg"
      },
      "source": [
        "### Displacements and Distances\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p9fNv_PXCWg3"
      },
      "source": [
        "from jax_md import space\n",
        "\n",
        "displacement, shift = space.periodic(side_length)\n",
        "\n",
        "print(displacement(R[0], R[1]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ObwSolXq7R6D"
      },
      "source": [
        "metric = space.metric(displacement)\n",
        "\n",
        "print(metric(R[0], R[1]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0bmwFmj57HdI"
      },
      "source": [
        "Compute distances between pairs of points"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EX--_0MoDV1m"
      },
      "source": [
        "displacement = space.map_product(displacement)\n",
        "metric = space.map_product(metric)\n",
        "\n",
        "print(metric(R[:3], R[:3]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MXxXuAja7W3-"
      },
      "source": [
        "### Total energy of a system"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tl6HLZwfEWcz"
      },
      "source": [
        "def energy(R):\n",
        "  \"\"\"Total soft-sphere energy of the configuration R.\"\"\"\n",
        "  pair_distances = metric(R, R)\n",
        "  pair_energies = soft_sphere(pair_distances)\n",
        "  # NOTE(review): the 0.5 presumably compensates for metric(R, R) holding\n",
        "  # both (i, j) and (j, i) - confirm against space.map_product.\n",
        "  return 0.5 * np.sum(pair_energies)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wZrWkqQkEs35"
      },
      "source": [
        "print(energy(R))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Hp41YZCmEtb6"
      },
      "source": [
        "print(grad(energy)(R).shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5mVlh-lr5FY0"
      },
      "source": [
        "### Minimization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vs96hWLR5RxV"
      },
      "source": [
        "from jax_md.minimize import fire_descent\n",
        "\n",
        "init_fn, apply_fn = fire_descent(energy, shift)\n",
        "\n",
        "state = init_fn(R)\n",
        "\n",
        "while np.max(np.abs(state.force)) > 1e-3:\n",
        "  state = apply_fn(state)\n",
        "\n",
        "draw(state.position)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c1V8ds1fe7oa"
      },
      "source": [
        "cond_fn = lambda state: np.max(np.abs(state.force)) > 1e-3"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rAqamj4fse3W"
      },
      "source": [
        "### Making it Fast"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9IJL_maT6qUM"
      },
      "source": [
        "def minimize(R):\n",
        "  \"\"\"Take 20 FIRE descent steps starting from R and return the final energy.\"\"\"\n",
        "  fire_init, fire_step = fire_descent(energy, shift)\n",
        "\n",
        "  fire_state = fire_init(R)\n",
        "  step_count = 20\n",
        "  for _ in range(step_count):\n",
        "    fire_state = fire_step(fire_state)\n",
        "\n",
        "  final_positions = fire_state.position\n",
        "  return energy(final_positions)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6ohMuvT38XcP"
      },
      "source": [
        "%%timeit\n",
        "minimize(R).block_until_ready()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZNMxdujG81-6"
      },
      "source": [
        "# `jit` is part of the public top-level `jax` namespace; the old `jax.api`\n",
        "# module has been removed from JAX, so importing from it fails.\n",
        "from jax import jit\n",
        "\n",
        "# Just-In-Time compile to GPU\n",
        "minimize = jit(minimize)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gDGx3UAE9CIn"
      },
      "source": [
        "# The first call incurs a compilation cost\n",
        "minimize(R)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zOr9jtW48w-s"
      },
      "source": [
        "%%timeit\n",
        "minimize(R).block_until_ready()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZFBLKVTDZN_z"
      },
      "source": [
        "from jax.lax import while_loop\n",
        "\n",
        "def minimize(R):\n",
        "  \"\"\"Run FIRE descent from R until `cond_fn` is false; return the positions.\"\"\"\n",
        "  fire_init, fire_step = fire_descent(energy, shift)\n",
        "\n",
        "  # Using a JAX loop reduces compilation cost\n",
        "  relaxed_state = while_loop(cond_fn, fire_step, fire_init(R))\n",
        "\n",
        "  return relaxed_state.position"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9Jesy9PRZc62"
      },
      "source": [
        "# `jit` is part of the public top-level `jax` namespace; the old `jax.api`\n",
        "# module has been removed from JAX, so importing from it fails.\n",
        "from jax import jit\n",
        "\n",
        "minimize = jit(minimize)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FkfvXICRZd3Z"
      },
      "source": [
        "R_is = minimize(R)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1yHZHshlZeVw"
      },
      "source": [
        "%%timeit\n",
        "minimize(R).block_until_ready()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6DUSO3go4b41"
      },
      "source": [
        "### Elastic Moduli"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RfTVgGXG49fC"
      },
      "source": [
        "def strain_energy(strain, R):\n",
        "  \"\"\"Soft-sphere energy of R with `strain` applied to every pair displacement.\"\"\"\n",
        "  pair_displacements = displacement(R, R)\n",
        "  strained_displacements = np.dot(pair_displacements, strain)\n",
        "  pair_distances = space.distance(strained_displacements)\n",
        "  pair_energies = soft_sphere(pair_distances)\n",
        "  return 0.5 * np.sum(pair_energies)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dV5s_OSB5iUS"
      },
      "source": [
        "strain_energy(np.eye(2),  R_is)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1oAkt_2K8Lwv"
      },
      "source": [
        "# `hessian` is part of the public top-level `jax` namespace; the old\n",
        "# `jax.api` module has been removed from JAX, so importing from it fails.\n",
        "from jax import hessian\n",
        "\n",
        "# Second derivative of the energy with respect to the strain (argument 0),\n",
        "# evaluated at the identity strain.\n",
        "K = hessian(strain_energy)(np.eye(2),  R_is)\n",
        "print(K.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6PtJhYbYzGIk"
      },
      "source": [
        "print(K)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l3oJRO__8ZyS"
      },
      "source": [
        "from jax_md.quantity import bulk_modulus\n",
        "\n",
        "B = bulk_modulus(K)\n",
        "print(B)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cQHLyhK19Nzh"
      },
      "source": [
        "G = K[0, 1, 0, 1]\n",
        "print(G)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1Ae0oeOdp-FV"
      },
      "source": [
        "from functools import partial\n",
        "\n",
        "@jit\n",
        "def elastic_moduli(number_density, key):\n",
        "  \"\"\"Relax a random soft-sphere system and measure its elastic response.\n",
        "\n",
        "  Args:\n",
        "    number_density: target number density, passed to\n",
        "      `box_size_at_number_density` to size the periodic box.\n",
        "    key: JAX PRNG key used to draw the initial positions.\n",
        "\n",
        "  Returns:\n",
        "    A pair `(bulk_modulus(K), K[0, 1, 0, 1])`, where `K` is the Hessian of the\n",
        "    energy with respect to strain, evaluated at the identity strain on the\n",
        "    minimized positions.\n",
        "  \"\"\"\n",
        "  # Randomly initialize particles across the whole periodic box. The uniform\n",
        "  # samples lie in [0, 1), so they must be scaled by side_length; otherwise\n",
        "  # every particle starts inside the unit square whatever the density is.\n",
        "  side_length = box_size_at_number_density(particle_count    = particle_count,\n",
        "                                           number_density    = number_density,\n",
        "                                           spatial_dimension = dim)\n",
        "  R = random.uniform(key, (particle_count, dim)) * side_length\n",
        "\n",
        "  displacement, shift = space.periodic(side_length)\n",
        "  displacement = space.map_product(displacement)\n",
        "\n",
        "  # Define an energy function at a specific strain.\n",
        "  def energy(strain, R):\n",
        "    dR = displacement(R, R) @ strain\n",
        "    dr = space.distance(dR)\n",
        "    return 0.5 * np.sum(soft_sphere(dr))\n",
        "\n",
        "  # Minimize at no strain.\n",
        "  init_fn, apply_fn = fire_descent(partial(energy, np.eye(2)), shift)\n",
        "\n",
        "  state = init_fn(R)\n",
        "  state = while_loop(cond_fn, apply_fn, state)\n",
        "\n",
        "  # Compute the elastic constants.\n",
        "  K = hessian(energy)(np.eye(2), state.position)\n",
        "  return bulk_modulus(K), K[0, 1, 0, 1]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-0zODBrXq-jd"
      },
      "source": [
        "number_densities = np.linspace(1.1, 1.6, 40)\n",
        "\n",
        "elastic_moduli = vmap(elastic_moduli, in_axes=(0, None))\n",
        "B, G = elastic_moduli(number_densities, key)\n",
        "\n",
        "plot(number_densities, B)\n",
        "plot(number_densities, G)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Eu0p8bFfPJJ4"
      },
      "source": [
        "plot(number_densities, G / B)\n",
        "plt.ylim([0, 1])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XZUNCbosLlf3"
      },
      "source": [
        "keys = random.split(key, 5)\n",
        "\n",
        "elastic_moduli = vmap(elastic_moduli, in_axes=(None, 0))\n",
        "B, G = elastic_moduli(number_densities, keys)\n",
        "\n",
        "for b, g in zip(B, G):\n",
        "  plt.plot(number_densities, g / b)\n",
        "\n",
        "plot(number_densities, np.mean(G / B, axis=0), 'k')\n",
        "plt.ylim([0, 1])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EeeS0GBb06US"
      },
      "source": [
        "### Going Big"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0mqD5mjY08ZC"
      },
      "source": [
        "key = random.PRNGKey(0)\n",
        "\n",
        "particle_count = 128000\n",
        "side_length = box_size_at_number_density(particle_count    = particle_count, \n",
        "                                         number_density    = 1.0, \n",
        "                                         spatial_dimension = dim)\n",
        "\n",
        "\n",
        "R = random.uniform(key, (particle_count, dim)) * side_length\n",
        "\n",
        "displacement, shift = space.periodic(side_length)\n",
        "\n",
        "draw_big(R)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7glVkdK31ZqG"
      },
      "source": [
        "from jax_md.energy import soft_sphere_neighbor_list\n",
        "\n",
        "neighbor_fn, energy_fn = soft_sphere_neighbor_list(displacement, side_length)\n",
        "\n",
        "init_fn, apply_fn = fire_descent(energy_fn, shift)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HLpaev18txCG"
      },
      "source": [
        "nbrs = neighbor_fn(R)\n",
        "print(nbrs.idx.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G8cjvj0X2VLG"
      },
      "source": [
        "from jax.lax import fori_loop\n",
        "\n",
        "state = init_fn(R, neighbor=nbrs)\n",
        "\n",
        "@jit\n",
        "def take_steps(state_and_nbrs):\n",
        "  \"\"\"Advance a (state, neighbor list) pair by 10 minimization steps.\"\"\"\n",
        "  def advance(_, carry):\n",
        "    fire_state, neighbors = carry\n",
        "    # Update the neighbor list at the current positions before stepping.\n",
        "    neighbors = neighbor_fn(fire_state.position, neighbors)\n",
        "    fire_state = apply_fn(fire_state, neighbor=neighbors)\n",
        "    return fire_state, neighbors\n",
        "  return fori_loop(0, 10, advance, state_and_nbrs)\n",
        "\n",
        "for i in ProgressIter(range(80)):\n",
        "  new_state, nbrs = take_steps((state, nbrs))\n",
        "\n",
        "  if nbrs.did_buffer_overflow:\n",
        "    nbrs = neighbor_fn(state.position)\n",
        "  else:\n",
        "    state = new_state\n",
        "\n",
        "draw_big(state.position)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "89DHZ6cMu3zd"
      },
      "source": [
        "## Neural Network Potentials"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x6rF6EEJc3P1"
      },
      "source": [
        "Here is some data we loaded for a 64-atom silicon system, computed using DFT."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "B9Ylk3d3c26h"
      },
      "source": [
        "print(Rs.shape)  # Positions\n",
        "print(Es.shape)  # Energies\n",
        "print(Fs.shape)  # Forces"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H27hVVxO3foZ"
      },
      "source": [
        "E_mean = np.mean(Es)\n",
        "E_std = np.std(Es)\n",
        "\n",
        "print(f'E_mean = {E_mean}, E_std = {E_std}')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PEP6AHHExork"
      },
      "source": [
        "plt.hist(Es)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XJOnNJ0Sxx-F"
      },
      "source": [
        "Set up the system and a Graph Neural Network energy function."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "naOzXLguq7Ya"
      },
      "source": [
        "box_size = 10.862\n",
        "displacement, shift = space.periodic(box_size)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "thfUg3VjcNxY"
      },
      "source": [
        "from jax_md.energy import graph_network\n",
        "\n",
        "init_fn, energy_fn = graph_network(displacement, r_cutoff=3.0)\n",
        "\n",
        "with open('si_gnn.pickle?raw=true', 'rb') as f:\n",
        "  params = pickle.load(f)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NCmw_Ub_cOyG"
      },
      "source": [
        "energy_fn(params, test_Rs[0])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EECLeW1HwP9f"
      },
      "source": [
        "@jit\n",
        "def scaled_energy_fn(R, **kwargs):\n",
        "  \"\"\"Network energy of R, rescaled as `energy * E_std + E_mean`.\n",
        "\n",
        "  Extra keyword arguments are accepted and ignored.\n",
        "  \"\"\"\n",
        "  raw_energy = energy_fn(params, R)\n",
        "  return raw_energy * E_std + E_mean"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CNH3gBJCiSXr"
      },
      "source": [
        "predicted_Es = vmap(scaled_energy_fn)(Rs)\n",
        "plt.plot(Es, predicted_Es, 'o')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6FVkIOSI3p0N"
      },
      "source": [
        "from jax_md.quantity import force\n",
        "\n",
        "force_fn = force(scaled_energy_fn)\n",
        "predicted_Fs = force_fn(Rs[1])\n",
        "\n",
        "plt.plot(Fs[1].reshape((-1,)), predicted_Fs.reshape((-1,)), 'o')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qnsJn7AljvRY"
      },
      "source": [
        "This energy can be used in a simulation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SnT7hj_w4CmB"
      },
      "source": [
        "from jax_md.simulate import nvt_nose_hoover\n",
        "from jax_md.quantity import temperature\n",
        "\n",
        "K_B = 8.617e-5\n",
        "dt = 1e-3\n",
        "kT = K_B * 300 \n",
        "Si_mass = 2.91086E-3\n",
        "\n",
        "init_fn, apply_fn = nvt_nose_hoover(scaled_energy_fn, shift, dt, kT)\n",
        "\n",
        "apply_fn = jit(apply_fn)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z3nbOc6ikkNq"
      },
      "source": [
        "state = init_fn(key, Rs[0], Si_mass)\n",
        "\n",
        "@jit\n",
        "def take_steps(state):\n",
        "  \"\"\"Apply `apply_fn` 1000 times in a row, starting from `state`.\"\"\"\n",
        "  def advance(_, current_state):\n",
        "    return apply_fn(current_state)\n",
        "  return fori_loop(0, 1000, advance, state)\n",
        "\n",
        "print('Energy (eV)\\tTemperature (K)')\n",
        "for i in range(10):\n",
        "  state = take_steps(state)\n",
        "\n",
        "  print('{:.02f}\\t\\t\\t{:.02f}'.format(\n",
        "      scaled_energy_fn(state.position),\n",
        "      temperature(state.velocity, Si_mass) / K_B))"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}